{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "extensions": {
     "jupyter_dashboards": {
      "version": 1,
      "views": {
       "grid_default": {},
       "report_default": {
        "hidden": false
       }
      }
     }
    }
   },
   "source": [
    "# Example: QAT ResNet18 Optimization\n",
    "\n",
    "This is the notebook described in the [TREx blog post](https://developer.nvidia.com/blog/exploring-tensorrt-engines-with-trex)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "extensions": {
     "jupyter_dashboards": {
      "version": 1,
      "views": {
       "grid_default": {},
       "report_default": {
        "hidden": false
       }
      }
     }
    }
   },
   "source": [
    "## Load JSON Files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "extensions": {
     "jupyter_dashboards": {
      "version": 1,
      "views": {
       "grid_default": {},
       "report_default": {
        "hidden": true
       }
      }
     }
    },
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import IPython\n",
    "from ipywidgets import widgets\n",
    "from trex import *\n",
    "\n",
    "# Widen the notebook output area so the wide graphs fit.\n",
    "set_wide_display()\n",
    "\n",
    "def create_plan(engine_path: str, engine_name):\n",
    "    \"\"\"Build an EnginePlan from the JSON files that share the `engine_path` prefix.\n",
    "\n",
    "    The three file names are formed by appending `.graph.json`, `.profile.json`\n",
    "    and `.profile.metadata.json` to `engine_path` and are passed positionally,\n",
    "    in that order; `engine_name` is forwarded to EnginePlan as its `name`.\n",
    "    \"\"\"\n",
    "    graph_json = f'{engine_path}.graph.json'\n",
    "    profile_json = f'{engine_path}.profile.json'\n",
    "    metadata_json = f'{engine_path}.profile.metadata.json'\n",
    "    return EnginePlan(graph_json, profile_json, metadata_json, name=engine_name)\n",
    "\n",
    "\n",
    "# Engine files are located relative to the current working directory\n",
    "# (symlinks resolved), so run this notebook from the directory that holds it.\n",
    "ipynb_path = os.path.realpath(os.getcwd())\n",
    "\n",
    "# The example engine files are looked up under <cwd>/A100.\n",
    "examples_dir = os.path.join(ipynb_path, \"A100\")\n",
    "# Engine display name -> engine file path relative to examples_dir.\n",
    "engines_info = {\n",
    "    \"ResNet18-FP32\": \"fp32/resnet.onnx.engine\",\n",
    "    \"ResNet18-FP16\": \"fp16/resnet.onnx.engine\",\n",
    "    \"ResNet18-QAT\":  \"qat/resnet-qat.onnx.engine\",\n",
    "    \"ResNet18-QAT+Residual\": \"qat-residual/resnet-qat-residual.onnx.engine\",\n",
    "    \"ResNet18-QAT+Residual+GAP\": \"qat-residual-qgap/resnet-qat-residual-qgap.onnx.engine\",\n",
    "}\n",
    "\n",
    "# Build one plan per engine, in the insertion order of engines_info\n",
    "# (later cells index into `plans` by position).\n",
    "plans = [\n",
    "    create_plan(os.path.join(examples_dir, engine_path), engine_name)\n",
    "    for engine_name, engine_path in engines_info.items()\n",
    "]\n",
    "compare_engines_summaries_tbl(plans, orientation='vertical')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Run the engines overview comparison over every loaded plan.\n",
    "compare_engines_overview(plans)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "compare_engines_layer_latencies(\n",
    "    plans[1],  # \"ResNet18-FP16\": second entry of engines_info\n",
    "    plans[3],  # \"ResNet18-QAT+Residual\": fourth entry of engines_info\n",
    "    # Grace threshold of 3% applied when color-highlighting performance differences.\n",
    "    threshold=0.03,\n",
    "    # Request exact matching; inexact matching uses only the layer's first input\n",
    "    # and output to match to other layers.\n",
    "    exact_matching=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert each plan to a DOT graph with to_dot, then hand the graph to\n",
    "# render_dot together with the plan's name and 'svg'.\n",
    "for plan in plans:\n",
    "    graph = to_dot(\n",
    "        plan,\n",
    "        layer_type_formatter,\n",
    "        display_regions=True,\n",
    "        expand_layer_details=True,\n",
    "    )\n",
    "    render_dot(graph, plan.name, 'svg')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "extensions": {
   "jupyter_dashboards": {
    "activeView": "report_default",
    "version": 1,
    "views": {
     "grid_default": {
      "name": "grid",
      "type": "grid"
     },
     "report_default": {
      "name": "report",
      "type": "report"
     }
    }
   }
  },
  "interpreter": {
   "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
